{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Torch-TensorRT Getting Started - CitriNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "[Citrinet](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/models.html#citrinet) is an acoustic model for speech-to-text recognition. It is a version of [QuartzNet](https://arxiv.org/pdf/1910.10261.pdf) that extends [ContextNet](https://arxiv.org/pdf/2005.03191.pdf): it uses subword encoding (via WordPiece tokenization) and a Squeeze-and-Excitation (SE) mechanism, which makes CitriNet models smaller than QuartzNet models.\n",
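    "\n",
    "As a rough illustration of the Squeeze-and-Excitation idea, the plain-Python sketch below gates each channel of a `C x T` (channels x time) feature map. It is not NeMo's implementation: the bottleneck weights `w1` (shape `C/r x C`, where `r` is the reduction ratio) and `w2` (shape `C x C/r`) are assumed to be given, biases are omitted, and nested lists stand in for tensors.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def sigmoid(v):\n",
    "    return 1.0 / (1.0 + math.exp(-v))\n",
    "\n",
    "\n",
    "def squeeze_excite(x, w1, w2):\n",
    "    # x: C channels, each a list of T time steps.\n",
    "    # Squeeze: average each channel over time to get one descriptor per channel.\n",
    "    z = [sum(ch) / len(ch) for ch in x]\n",
    "    # Excitation: bottleneck layer with ReLU, then a layer with sigmoid,\n",
    "    # giving one gate in (0, 1) per channel.\n",
    "    h = [max(0.0, sum(w * zc for w, zc in zip(row, z))) for row in w1]\n",
    "    s = [sigmoid(sum(w * hc for w, hc in zip(row, h))) for row in w2]\n",
    "    # Scale: multiply every time step of a channel by that channel's gate.\n",
    "    return [[g * v for v in ch] for g, ch in zip(s, x)]\n",
    "```\n",
    "\n",
    "For example, with `w2` set to zeros every gate is `sigmoid(0) = 0.5`, so `squeeze_excite([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]], [[1.0, 1.0]], [[0.0], [0.0]])` returns `[[0.5, 1.0, 1.5], [0.0, 0.0, 0.0]]`. Because each gate is computed from a channel's average over all time steps, SE lets a block reweight its channels using global context.\n",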
    "\n",
    "CitriNet models take in audio segments and transcribe them into letter, byte-pair, or word-piece sequences.\n",
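    "\n",
    "To show how per-frame predictions become letter or word-piece sequences, here is a minimal greedy decoding sketch. It assumes a CTC-style output (the CitriNet checkpoints in NeMo are CTC models): take the most likely token per frame, merge consecutive repeats, and drop the blank token. The helper names, the toy vocabulary, and the `##` continuation prefix used to join word pieces are illustrative assumptions, not NeMo APIs.\n",
    "\n",
    "```python\n",
    "def ctc_greedy_collapse(frame_ids, vocab, blank_id):\n",
    "    # frame_ids: the argmax token id of every output frame, in time order.\n",
    "    tokens = []\n",
    "    prev = None\n",
    "    for t in frame_ids:\n",
    "        # Emit a token only when it is not blank and differs from the previous frame;\n",
    "        # a blank between two identical ids lets the same token be emitted twice.\n",
    "        if t != blank_id and t != prev:\n",
    "            tokens.append(vocab[t])\n",
    "        prev = t\n",
    "    return tokens\n",
    "\n",
    "\n",
    "def join_word_pieces(pieces):\n",
    "    # Assumed convention: a '##' prefix marks a piece that continues the previous word.\n",
    "    words = []\n",
    "    for p in pieces:\n",
    "        if p.startswith('##') and words:\n",
    "            words[-1] += p[2:]\n",
    "        else:\n",
    "            words.append(p)\n",
    "    return ' '.join(words)\n",
    "\n",
    "\n",
    "vocab = ['hel', '##lo', 'world', '_']  # '_' plays the role of the blank token\n",
    "frames = [0, 0, 3, 1, 1, 3, 3, 2]\n",
    "print(join_word_pieces(ctc_greedy_collapse(frames, vocab, blank_id=3)))  # hello world\n",
    "```\n",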
    "\n",
    "<img src=\"https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/_images/jasper_vertical.png\" alt=\"alt\" width=\"50%\"/>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning objectives\n",
    "\n",
    "This notebook demonstrates the steps for optimizing a pretrained CitriNet model with Torch-TensorRT and running it to measure the resulting speedup.\n",
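    "\n",
    "Measuring the speedup comes down to timing the same inference call for each model and taking the ratio. The sketch below shows one common pattern using only the standard library: warm-up calls first, then timed runs summarized by the median. The `benchmark` and `speedup` helpers are hypothetical; when timing GPU models, pass a synchronization function such as `torch.cuda.synchronize` as `sync`, because CUDA kernels are launched asynchronously.\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "import time\n",
    "\n",
    "\n",
    "def benchmark(fn, warmup=10, runs=100, sync=None):\n",
    "    # Warm-up calls absorb one-time costs (lazy initialization, caching, autotuning).\n",
    "    for _ in range(warmup):\n",
    "        fn()\n",
    "    if sync is not None:\n",
    "        sync()\n",
    "    timings_ms = []\n",
    "    for _ in range(runs):\n",
    "        start = time.perf_counter()\n",
    "        fn()\n",
    "        # Wait for queued asynchronous (e.g. GPU) work before stopping the clock.\n",
    "        if sync is not None:\n",
    "            sync()\n",
    "        timings_ms.append((time.perf_counter() - start) * 1000.0)\n",
    "    # The median is less sensitive than the mean to occasional slow outliers.\n",
    "    return statistics.median(timings_ms)\n",
    "\n",
    "\n",
    "def speedup(baseline_ms, optimized_ms):\n",
    "    return baseline_ms / optimized_ms\n",
    "```\n",
    "\n",
    "For example, `speedup(benchmark(run_baseline, sync=torch.cuda.synchronize), benchmark(run_trt, sync=torch.cuda.synchronize))`, where `run_baseline` and `run_trt` are hypothetical zero-argument callables that run one forward pass of the original and the Torch-TensorRT model on the same input.\n",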
    "\n",
    "## Content\n",
    "1. [Requirements](#1)\n",
    "1. [Download Citrinet model](#2)\n",
    "1. [Create Torch-TensorRT modules](#3)\n",
    "1. [Benchmark Torch-TensorRT models](#4)\n",
    "1. [Conclusion](#5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"1\"></a>\n",
    "## 1. Requirements\n",
    "\n",
    "Follow the steps in [README](README.md) to prepare a Docker container, within which you can run this notebook. \n",
    "This notebook assumes that you are in a Jupyter environment inside a Docker container with Torch-TensorRT installed, such as an NGC monthly release of `nvcr.io/nvidia/pytorch:<yy.mm>-py3` (where `yy` is the last two digits of the calendar year and `mm` is the month as two digits).\n",
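    "\n",
    "The tag format can be spelled out with a small helper. This only illustrates the naming scheme: `ngc_pytorch_image` is a hypothetical function, and whether a given `yy.mm` release exists (and includes Torch-TensorRT) must be checked against the NGC catalog.\n",
    "\n",
    "```python\n",
    "def ngc_pytorch_image(year, month):\n",
    "    # yy: last two digits of the calendar year; mm: month as two zero-padded digits.\n",
    "    if month not in range(1, 13):\n",
    "        raise ValueError('month must be between 1 and 12')\n",
    "    return 'nvcr.io/nvidia/pytorch:{:02d}.{:02d}-py3'.format(year % 100, month)\n",
    "\n",
    "\n",
    "print(ngc_pytorch_image(2022, 4))  # nvcr.io/nvidia/pytorch:22.04-py3\n",
    "```\n",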
    "\n",
    "Now that you are inside the Docker container, the next step is to install the required dependencies."
   ]
  },
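  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before or after running the installation cell below, you can check which of the required Python packages are already importable. This is a minimal sketch using `importlib.util.find_spec`; the module names listed (`nemo` for `nemo_toolkit`, `torch_tensorrt`, `wget`) are assumptions about what this notebook needs, so adjust the list as required.\n",
    "\n",
    "```python\n",
    "import importlib.util\n",
    "\n",
    "\n",
    "def missing_modules(names):\n",
    "    # find_spec returns None when a top-level module cannot be located.\n",
    "    return [name for name in names if importlib.util.find_spec(name) is None]\n",
    "\n",
    "\n",
    "print(missing_modules(['torch', 'torch_tensorrt', 'nemo', 'wget']))\n",
    "```"
   ]
  },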
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: wget in /opt/conda/lib/python3.8/site-packages (3.2)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
      "Hit:1 http://security.ubuntu.com/ubuntu focal-security InRelease\n",
      "Hit:2 http://archive.ubuntu.com/ubuntu focal InRelease\n",
      "Hit:3 http://archive.ubuntu.com/ubuntu focal-updates InRelease\n",
      "Hit:4 http://archive.ubuntu.com/ubuntu focal-backports InRelease\n",
      "Reading package lists... Done\n",
      "Reading package lists... Done\n",
      "Building dependency tree       \n",
      "Reading state information... Done\n",
      "libsndfile1 is already the newest version (1.0.28-7ubuntu0.1).\n",
      "ffmpeg is already the newest version (7:4.2.4-1ubuntu0.1).\n",
      "0 upgraded, 0 newly installed, 0 to remove and 22 not upgraded.\n",
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: Cython in /opt/conda/lib/python3.8/site-packages (0.29.28)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n",
      "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
      "Requirement already satisfied: nemo_toolkit[all]==1.5.1 in /opt/conda/lib/python3.8/site-packages (1.5.1)\n",
      "Requirement already satisfied: numpy>=1.18.2 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.22.3)\n",
      "Requirement already satisfied: onnx>=1.7.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.10.1)\n",
      "Requirement already satisfied: python-dateutil in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.8.2)\n",
      "Requirement already satisfied: tqdm>=4.41.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.63.0)\n",
      "Requirement already satisfied: sentencepiece<1.0.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.1.96)\n",
      "Requirement already satisfied: wget in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (3.2)\n",
      "Requirement already satisfied: numba in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.53.1)\n",
      "Requirement already satisfied: torch in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.12.0a0+2c916ef)\n",
      "Requirement already satisfied: unidecode in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.3.4)\n",
      "Requirement already satisfied: frozendict in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.3.2)\n",
      "Requirement already satisfied: wrapt in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.14.0)\n",
      "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.24.2)\n",
      "Requirement already satisfied: ruamel.yaml in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.17.21)\n",
      "Requirement already satisfied: pesq in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.0.3)\n",
      "Requirement already satisfied: torchvision in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.13.0a0)\n",
      "Requirement already satisfied: gdown in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.4.0)\n",
      "Requirement already satisfied: editdistance in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.6.0)\n",
      "Requirement already satisfied: boto3 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.21.45)\n",
      "Requirement already satisfied: isort[requirements]<5 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.3.21)\n",
      "Requirement already satisfied: hydra-core>=1.1.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.1.2)\n",
      "Requirement already satisfied: youtokentome>=1.0.5 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.0.6)\n",
      "Requirement already satisfied: pytorch-lightning>=1.5.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.6.1)\n",
      "Requirement already satisfied: jieba in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.42.1)\n",
      "Requirement already satisfied: fasttext in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.9.2)\n",
      "Requirement already satisfied: soundfile in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.10.3.post1)\n",
      "Requirement already satisfied: kaldiio in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.17.2)\n",
      "Requirement already satisfied: pangu in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.0.6.1)\n",
      "Requirement already satisfied: kaldi-python-io in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.2.2)\n",
      "Requirement already satisfied: parameterized in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.8.1)\n",
      "Requirement already satisfied: h5py in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (3.6.0)\n",
      "Requirement already satisfied: rapidfuzz in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.0.10)\n",
      "Requirement already satisfied: marshmallow in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (3.15.0)\n",
      "Requirement already satisfied: opencc in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.1.3)\n",
      "Requirement already satisfied: braceexpand in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.1.7)\n",
      "Requirement already satisfied: omegaconf>=2.1.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.1.2)\n",
      "Requirement already satisfied: sphinx in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.4.0)\n",
      "Requirement already satisfied: pillow in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (9.0.0)\n",
      "Requirement already satisfied: wordninja==2.0.0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.0.0)\n",
      "Requirement already satisfied: torch-stft in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.1.4)\n",
      "Requirement already satisfied: sox in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.4.1)\n",
      "Requirement already satisfied: librosa in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.9.1)\n",
      "Requirement already satisfied: regex in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2022.3.15)\n",
      "Requirement already satisfied: sacrebleu[ja] in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.0.0)\n",
      "Requirement already satisfied: black==19.10b0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (19.10b0)\n",
      "Requirement already satisfied: pydub in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.25.1)\n",
      "Requirement already satisfied: sphinxcontrib-bibtex in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.4.2)\n",
      "Requirement already satisfied: inflect in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (5.5.2)\n",
      "Requirement already satisfied: pyannote.core in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.4)\n",
      "Requirement already satisfied: packaging in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (21.3)\n",
      "Requirement already satisfied: kaldi-io in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.9.4)\n",
      "Requirement already satisfied: pyannote.metrics in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (3.2)\n",
      "Requirement already satisfied: g2p-en in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.1.0)\n",
      "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (3.5.1)\n",
      "Requirement already satisfied: torchmetrics>=0.4.1rc0 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.8.0)\n",
      "Requirement already satisfied: nltk>=3.6.5 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (3.7)\n",
      "Requirement already satisfied: pyyaml<6 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (5.4.1)\n",
      "Requirement already satisfied: scipy in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.6.3)\n",
      "Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (7.7.0)\n",
      "Requirement already satisfied: pytest in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (6.2.5)\n",
      "Requirement already satisfied: pandas in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (1.3.5)\n",
      "Requirement already satisfied: pytest-runner in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (6.0.0)\n",
      "Requirement already satisfied: transformers>=4.0.1 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (4.18.0)\n",
      "Requirement already satisfied: sacremoses>=0.0.43 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.0.49)\n",
      "Requirement already satisfied: pystoi in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.3.3)\n",
      "Requirement already satisfied: attrdict in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (2.0.1)\n",
      "Requirement already satisfied: webdataset<=0.1.62,>=0.1.48 in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.1.62)\n",
      "Requirement already satisfied: wandb in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.12.15)\n",
      "Requirement already satisfied: pypinyin in /opt/conda/lib/python3.8/site-packages (from nemo_toolkit[all]==1.5.1) (0.46.0)\n",
      "Requirement already satisfied: attrs>=18.1.0 in /opt/conda/lib/python3.8/site-packages (from black==19.10b0->nemo_toolkit[all]==1.5.1) (21.4.0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: appdirs in /opt/conda/lib/python3.8/site-packages (from black==19.10b0->nemo_toolkit[all]==1.5.1) (1.4.4)\n",
      "Requirement already satisfied: typed-ast>=1.4.0 in /opt/conda/lib/python3.8/site-packages (from black==19.10b0->nemo_toolkit[all]==1.5.1) (1.5.3)\n",
      "Requirement already satisfied: pathspec<1,>=0.6 in /opt/conda/lib/python3.8/site-packages (from black==19.10b0->nemo_toolkit[all]==1.5.1) (0.9.0)\n",
      "Requirement already satisfied: click>=6.5 in /opt/conda/lib/python3.8/site-packages (from black==19.10b0->nemo_toolkit[all]==1.5.1) (8.0.4)\n",
      "Requirement already satisfied: toml>=0.9.4 in /opt/conda/lib/python3.8/site-packages (from black==19.10b0->nemo_toolkit[all]==1.5.1) (0.10.2)\n",
      "Requirement already satisfied: antlr4-python3-runtime==4.8 in /opt/conda/lib/python3.8/site-packages (from hydra-core>=1.1.0->nemo_toolkit[all]==1.5.1) (4.8)\n",
      "Requirement already satisfied: importlib-resources<5.3 in /opt/conda/lib/python3.8/site-packages (from hydra-core>=1.1.0->nemo_toolkit[all]==1.5.1) (5.2.3)\n",
      "Requirement already satisfied: zipp>=3.1.0 in /opt/conda/lib/python3.8/site-packages (from importlib-resources<5.3->hydra-core>=1.1.0->nemo_toolkit[all]==1.5.1) (3.7.0)\n",
      "Requirement already satisfied: pip-api in /opt/conda/lib/python3.8/site-packages (from isort[requirements]<5->nemo_toolkit[all]==1.5.1) (0.0.29)\n",
      "Requirement already satisfied: pipreqs in /opt/conda/lib/python3.8/site-packages (from isort[requirements]<5->nemo_toolkit[all]==1.5.1) (0.4.11)\n",
      "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.8/site-packages (from matplotlib->nemo_toolkit[all]==1.5.1) (4.31.2)\n",
      "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.8/site-packages (from matplotlib->nemo_toolkit[all]==1.5.1) (1.4.0)\n",
      "Requirement already satisfied: pyparsing>=2.2.1 in /opt/conda/lib/python3.8/site-packages (from matplotlib->nemo_toolkit[all]==1.5.1) (3.0.7)\n",
      "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.8/site-packages (from matplotlib->nemo_toolkit[all]==1.5.1) (0.11.0)\n",
      "Requirement already satisfied: joblib in /opt/conda/lib/python3.8/site-packages (from nltk>=3.6.5->nemo_toolkit[all]==1.5.1) (1.1.0)\n",
      "Requirement already satisfied: typing-extensions>=3.6.2.1 in /opt/conda/lib/python3.8/site-packages (from onnx>=1.7.0->nemo_toolkit[all]==1.5.1) (4.1.1)\n",
      "Requirement already satisfied: six in /opt/conda/lib/python3.8/site-packages (from onnx>=1.7.0->nemo_toolkit[all]==1.5.1) (1.16.0)\n",
      "Requirement already satisfied: protobuf>=3.12.2 in /opt/conda/lib/python3.8/site-packages (from onnx>=1.7.0->nemo_toolkit[all]==1.5.1) (3.19.4)\n",
      "Requirement already satisfied: pyDeprecate<0.4.0,>=0.3.1 in /opt/conda/lib/python3.8/site-packages (from pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (0.3.2)\n",
      "Requirement already satisfied: tensorboard>=2.2.0 in /opt/conda/lib/python3.8/site-packages (from pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2.8.0)\n",
      "Requirement already satisfied: fsspec[http]!=2021.06.0,>=2021.05.0 in /opt/conda/lib/python3.8/site-packages (from pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2022.2.0)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.8/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2.27.1)\n",
      "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.8/site-packages (from fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (3.8.1)\n",
      "Requirement already satisfied: werkzeug>=0.11.15 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2.0.3)\n",
      "Requirement already satisfied: markdown>=2.6.8 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (3.3.6)\n",
      "Requirement already satisfied: setuptools>=41.0.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (59.5.0)\n",
      "Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (0.4.6)\n",
      "Requirement already satisfied: google-auth<3,>=1.6.3 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2.6.2)\n",
      "Requirement already satisfied: wheel>=0.26 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (0.37.1)\n",
      "Requirement already satisfied: grpcio>=1.24.3 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.44.0)\n",
      "Requirement already satisfied: tensorboard-data-server<0.7.0,>=0.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (0.6.1)\n",
      "Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.8.1)\n",
      "Requirement already satisfied: absl-py>=0.4 in /opt/conda/lib/python3.8/site-packages (from tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.0.0)\n",
      "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (5.0.0)\n",
      "Requirement already satisfied: pyasn1-modules>=0.2.1 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (0.2.8)\n",
      "Requirement already satisfied: rsa<5,>=3.1.4 in /opt/conda/lib/python3.8/site-packages (from google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (4.8)\n",
      "Requirement already satisfied: requests-oauthlib>=0.7.0 in /opt/conda/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.3.1)\n",
      "Requirement already satisfied: importlib-metadata>=4.4 in /opt/conda/lib/python3.8/site-packages (from markdown>=2.6.8->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (4.11.3)\n",
      "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /opt/conda/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<3,>=1.6.3->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (0.4.8)\n",
      "Requirement already satisfied: charset-normalizer~=2.0.0 in /opt/conda/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2.0.12)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (2021.10.8)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.26.8)\n",
      "Requirement already satisfied: oauthlib>=3.0.0 in /opt/conda/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.2.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (3.2.0)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.1.0 in /opt/conda/lib/python3.8/site-packages (from transformers>=4.0.1->nemo_toolkit[all]==1.5.1) (0.5.1)\n",
      "Requirement already satisfied: tokenizers!=0.11.3,<0.13,>=0.11.1 in /opt/conda/lib/python3.8/site-packages (from transformers>=4.0.1->nemo_toolkit[all]==1.5.1) (0.12.1)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.8/site-packages (from transformers>=4.0.1->nemo_toolkit[all]==1.5.1) (3.6.0)\n",
      "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.3.0)\n",
      "Requirement already satisfied: yarl<2.0,>=1.0 in /opt/conda/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.7.2)\n",
      "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /opt/conda/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (4.0.2)\n",
      "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (6.0.2)\n",
      "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.8/site-packages (from aiohttp->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.2.0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: s3transfer<0.6.0,>=0.5.0 in /opt/conda/lib/python3.8/site-packages (from boto3->nemo_toolkit[all]==1.5.1) (0.5.2)\n",
      "Requirement already satisfied: botocore<1.25.0,>=1.24.45 in /opt/conda/lib/python3.8/site-packages (from boto3->nemo_toolkit[all]==1.5.1) (1.24.45)\n",
      "Requirement already satisfied: jmespath<2.0.0,>=0.7.1 in /opt/conda/lib/python3.8/site-packages (from boto3->nemo_toolkit[all]==1.5.1) (1.0.0)\n",
      "Requirement already satisfied: pybind11>=2.2 in /opt/conda/lib/python3.8/site-packages (from fasttext->nemo_toolkit[all]==1.5.1) (2.9.1)\n",
      "Requirement already satisfied: distance>=0.1.3 in /opt/conda/lib/python3.8/site-packages (from g2p-en->nemo_toolkit[all]==1.5.1) (0.1.3)\n",
      "Requirement already satisfied: beautifulsoup4 in /opt/conda/lib/python3.8/site-packages (from gdown->nemo_toolkit[all]==1.5.1) (4.10.0)\n",
      "Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.8/site-packages (from beautifulsoup4->gdown->nemo_toolkit[all]==1.5.1) (2.3.1)\n",
      "Requirement already satisfied: ipython-genutils~=0.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (0.2.0)\n",
      "Requirement already satisfied: ipython>=4.0.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (8.1.1)\n",
      "Requirement already satisfied: ipykernel>=4.5.1 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (6.9.2)\n",
      "Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (1.1.0)\n",
      "Requirement already satisfied: widgetsnbextension~=3.6.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (3.6.0)\n",
      "Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (5.1.1)\n",
      "Requirement already satisfied: nbformat>=4.2.0 in /opt/conda/lib/python3.8/site-packages (from ipywidgets->nemo_toolkit[all]==1.5.1) (5.2.0)\n",
      "Requirement already satisfied: jupyter-client<8.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (7.1.2)\n",
      "Requirement already satisfied: psutil in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (5.9.0)\n",
      "Requirement already satisfied: tornado<7.0,>=4.2 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (6.1)\n",
      "Requirement already satisfied: debugpy<2.0,>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (1.5.1)\n",
      "Requirement already satisfied: nest-asyncio in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (1.5.4)\n",
      "Requirement already satisfied: matplotlib-inline<0.2.0,>=0.1.0 in /opt/conda/lib/python3.8/site-packages (from ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (0.1.3)\n",
      "Requirement already satisfied: pickleshare in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.7.5)\n",
      "Requirement already satisfied: decorator in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (5.1.1)\n",
      "Requirement already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (2.11.2)\n",
      "Requirement already satisfied: stack-data in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.2.0)\n",
      "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (3.0.27)\n",
      "Requirement already satisfied: backcall in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.2.0)\n",
      "Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.18.1)\n",
      "Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.8/site-packages (from ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (4.8.0)\n",
      "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /opt/conda/lib/python3.8/site-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.8.3)\n",
      "Requirement already satisfied: entrypoints in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (0.4)\n",
      "Requirement already satisfied: pyzmq>=13 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (22.3.0)\n",
      "Requirement already satisfied: jupyter-core>=4.6.0 in /opt/conda/lib/python3.8/site-packages (from jupyter-client<8.0->ipykernel>=4.5.1->ipywidgets->nemo_toolkit[all]==1.5.1) (4.9.2)\n",
      "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /opt/conda/lib/python3.8/site-packages (from nbformat>=4.2.0->ipywidgets->nemo_toolkit[all]==1.5.1) (4.4.0)\n",
      "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /opt/conda/lib/python3.8/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.18.1)\n",
      "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.8/site-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.7.0)\n",
      "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.8/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.2.5)\n",
      "Requirement already satisfied: notebook>=4.4.1 in /opt/conda/lib/python3.8/site-packages (from widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (6.4.1)\n",
      "Requirement already satisfied: Send2Trash>=1.5.0 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (1.8.0)\n",
      "Requirement already satisfied: prometheus-client in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.13.1)\n",
      "Requirement already satisfied: terminado>=0.8.3 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.13.3)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (3.0.3)\n",
      "Requirement already satisfied: nbconvert in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (6.4.4)\n",
      "Requirement already satisfied: argon2-cffi in /opt/conda/lib/python3.8/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (21.3.0)\n",
      "Requirement already satisfied: argon2-cffi-bindings in /opt/conda/lib/python3.8/site-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (21.2.0)\n",
      "Requirement already satisfied: cffi>=1.0.1 in /opt/conda/lib/python3.8/site-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (1.15.0)\n",
      "Requirement already satisfied: pycparser in /opt/conda/lib/python3.8/site-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (2.21)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.8/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (2.1.1)\n",
      "Requirement already satisfied: resampy>=0.2.2 in /opt/conda/lib/python3.8/site-packages (from librosa->nemo_toolkit[all]==1.5.1) (0.2.2)\n",
      "Requirement already satisfied: pooch>=1.0 in /opt/conda/lib/python3.8/site-packages (from librosa->nemo_toolkit[all]==1.5.1) (1.6.0)\n",
      "Requirement already satisfied: audioread>=2.1.5 in /opt/conda/lib/python3.8/site-packages (from librosa->nemo_toolkit[all]==1.5.1) (2.1.9)\n",
      "Requirement already satisfied: llvmlite<0.37,>=0.36.0rc1 in /opt/conda/lib/python3.8/site-packages (from numba->nemo_toolkit[all]==1.5.1) (0.36.0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn->nemo_toolkit[all]==1.5.1) (3.1.0)\n",
      "Requirement already satisfied: defusedxml in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.7.1)\n",
      "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.5.13)\n",
      "Requirement already satisfied: bleach in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (4.1.0)\n",
      "Requirement already satisfied: mistune<2,>=0.8.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.8.4)\n",
      "Requirement already satisfied: pandocfilters>=1.4.1 in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (1.5.0)\n",
      "Requirement already satisfied: testpath in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.6.0)\n",
      "Requirement already satisfied: jupyterlab-pygments in /opt/conda/lib/python3.8/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.1.2)\n",
      "Requirement already satisfied: webencodings in /opt/conda/lib/python3.8/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.5.1)\n",
      "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.8/site-packages (from pandas->nemo_toolkit[all]==1.5.1) (2021.3)\n",
      "Requirement already satisfied: pip in /opt/conda/lib/python3.8/site-packages (from pip-api->isort[requirements]<5->nemo_toolkit[all]==1.5.1) (21.2.4)\n",
      "Requirement already satisfied: yarg in /opt/conda/lib/python3.8/site-packages (from pipreqs->isort[requirements]<5->nemo_toolkit[all]==1.5.1) (0.1.9)\n",
      "Requirement already satisfied: docopt in /opt/conda/lib/python3.8/site-packages (from pipreqs->isort[requirements]<5->nemo_toolkit[all]==1.5.1) (0.6.2)\n",
      "Requirement already satisfied: simplejson>=3.8.1 in /opt/conda/lib/python3.8/site-packages (from pyannote.core->nemo_toolkit[all]==1.5.1) (3.17.6)\n",
      "Requirement already satisfied: sortedcontainers>=2.0.4 in /opt/conda/lib/python3.8/site-packages (from pyannote.core->nemo_toolkit[all]==1.5.1) (2.4.0)\n",
      "Requirement already satisfied: tabulate>=0.7.7 in /opt/conda/lib/python3.8/site-packages (from pyannote.metrics->nemo_toolkit[all]==1.5.1) (0.8.9)\n",
      "Requirement already satisfied: pyannote.database>=4.0.1 in /opt/conda/lib/python3.8/site-packages (from pyannote.metrics->nemo_toolkit[all]==1.5.1) (4.1.3)\n",
      "Requirement already satisfied: sympy>=1.1 in /opt/conda/lib/python3.8/site-packages (from pyannote.metrics->nemo_toolkit[all]==1.5.1) (1.10.1)\n",
      "Requirement already satisfied: typer[all]>=0.2.1 in /opt/conda/lib/python3.8/site-packages (from pyannote.database>=4.0.1->pyannote.metrics->nemo_toolkit[all]==1.5.1) (0.4.0)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.8/site-packages (from sympy>=1.1->pyannote.metrics->nemo_toolkit[all]==1.5.1) (1.2.1)\n",
      "Requirement already satisfied: colorama<0.5.0,>=0.4.3 in /opt/conda/lib/python3.8/site-packages (from typer[all]>=0.2.1->pyannote.database>=4.0.1->pyannote.metrics->nemo_toolkit[all]==1.5.1) (0.4.4)\n",
      "Requirement already satisfied: shellingham<2.0.0,>=1.3.0 in /opt/conda/lib/python3.8/site-packages (from typer[all]>=0.2.1->pyannote.database>=4.0.1->pyannote.metrics->nemo_toolkit[all]==1.5.1) (1.4.0)\n",
      "Requirement already satisfied: py>=1.8.2 in /opt/conda/lib/python3.8/site-packages (from pytest->nemo_toolkit[all]==1.5.1) (1.11.0)\n",
      "Requirement already satisfied: iniconfig in /opt/conda/lib/python3.8/site-packages (from pytest->nemo_toolkit[all]==1.5.1) (1.1.1)\n",
      "Requirement already satisfied: pluggy<2.0,>=0.12 in /opt/conda/lib/python3.8/site-packages (from pytest->nemo_toolkit[all]==1.5.1) (1.0.0)\n",
      "Requirement already satisfied: jarowinkler<1.1.0,>=1.0.2 in /opt/conda/lib/python3.8/site-packages (from rapidfuzz->nemo_toolkit[all]==1.5.1) (1.0.2)\n",
      "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /opt/conda/lib/python3.8/site-packages (from requests->fsspec[http]!=2021.06.0,>=2021.05.0->pytorch-lightning>=1.5.0->nemo_toolkit[all]==1.5.1) (1.7.1)\n",
      "Requirement already satisfied: ruamel.yaml.clib>=0.2.6 in /opt/conda/lib/python3.8/site-packages (from ruamel.yaml->nemo_toolkit[all]==1.5.1) (0.2.6)\n",
      "Requirement already satisfied: portalocker in /opt/conda/lib/python3.8/site-packages (from sacrebleu[ja]->nemo_toolkit[all]==1.5.1) (2.4.0)\n",
      "Requirement already satisfied: ipadic<2.0,>=1.0 in /opt/conda/lib/python3.8/site-packages (from sacrebleu[ja]->nemo_toolkit[all]==1.5.1) (1.0.0)\n",
      "Requirement already satisfied: mecab-python3==1.0.3 in /opt/conda/lib/python3.8/site-packages (from sacrebleu[ja]->nemo_toolkit[all]==1.5.1) (1.0.3)\n",
      "Requirement already satisfied: sphinxcontrib-htmlhelp>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (2.0.0)\n",
      "Requirement already satisfied: alabaster<0.8,>=0.7 in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (0.7.12)\n",
      "Requirement already satisfied: babel>=1.3 in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (2.9.1)\n",
      "Requirement already satisfied: sphinxcontrib-serializinghtml>=1.1.5 in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (1.1.5)\n",
      "Requirement already satisfied: sphinxcontrib-devhelp in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (1.0.2)\n",
      "Requirement already satisfied: sphinxcontrib-jsmath in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (1.0.1)\n",
      "Requirement already satisfied: sphinxcontrib-qthelp in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (1.0.3)\n",
      "Requirement already satisfied: snowballstemmer>=1.1 in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (2.2.0)\n",
      "Requirement already satisfied: imagesize in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (1.3.0)\n",
      "Requirement already satisfied: sphinxcontrib-applehelp in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (1.0.2)\n",
      "Requirement already satisfied: docutils<0.18,>=0.14 in /opt/conda/lib/python3.8/site-packages (from sphinx->nemo_toolkit[all]==1.5.1) (0.17.1)\n",
      "Requirement already satisfied: pybtex-docutils>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from sphinxcontrib-bibtex->nemo_toolkit[all]==1.5.1) (1.0.1)\n",
      "Requirement already satisfied: pybtex>=0.24 in /opt/conda/lib/python3.8/site-packages (from sphinxcontrib-bibtex->nemo_toolkit[all]==1.5.1) (0.24.0)\n",
      "Requirement already satisfied: latexcodec>=1.0.4 in /opt/conda/lib/python3.8/site-packages (from pybtex>=0.24->sphinxcontrib-bibtex->nemo_toolkit[all]==1.5.1) (2.0.1)\n",
      "Requirement already satisfied: pure-eval in /opt/conda/lib/python3.8/site-packages (from stack-data->ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.2.2)\n",
      "Requirement already satisfied: asttokens in /opt/conda/lib/python3.8/site-packages (from stack-data->ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (2.0.5)\n",
      "Requirement already satisfied: executing in /opt/conda/lib/python3.8/site-packages (from stack-data->ipython>=4.0.0->ipywidgets->nemo_toolkit[all]==1.5.1) (0.8.3)\n",
      "Requirement already satisfied: pathtools in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (0.1.2)\n",
      "Requirement already satisfied: setproctitle in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (1.2.3)\n",
      "Requirement already satisfied: GitPython>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (3.1.27)\n",
      "Requirement already satisfied: sentry-sdk>=1.0.0 in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (1.5.10)\n",
      "Requirement already satisfied: shortuuid>=0.5.0 in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (1.0.8)\n",
      "Requirement already satisfied: docker-pycreds>=0.4.0 in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (0.4.0)\n",
      "Requirement already satisfied: promise<3,>=2.0 in /opt/conda/lib/python3.8/site-packages (from wandb->nemo_toolkit[all]==1.5.1) (2.3)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: gitdb<5,>=4.0.1 in /opt/conda/lib/python3.8/site-packages (from GitPython>=1.0.0->wandb->nemo_toolkit[all]==1.5.1) (4.0.9)\n",
      "Requirement already satisfied: smmap<6,>=3.0.1 in /opt/conda/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython>=1.0.0->wandb->nemo_toolkit[all]==1.5.1) (5.0.0)\n",
      "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "# Install dependencies\n",
    "!pip install wget\n",
    "!apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y libsndfile1 ffmpeg\n",
    "!pip install Cython\n",
    "\n",
    "# Install NeMo (quoted so the shell does not treat the brackets as a glob)\n",
    "!pip install \"nemo_toolkit[all]==1.5.1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"2\"></a>\n",
    "## 2. Download Citrinet model\n",
    "\n",
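    "Next, we download a pretrained NeMo Citrinet model and export it to a TorchScript module. Because we save to a `.ts` file, the result is TorchScript even though the returned description says \"exported to ONNX\"; the saved file is loaded with `torch.jit.load` below."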
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nemo\n",
    "import torch\n",
    "\n",
    "import nemo.collections.asr as nemo_asr\n",
    "from nemo.core import typecheck\n",
    "\n",
    "# Disable NeMo's neural type checking before exporting the model\n",
    "typecheck.set_typecheck_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading and saving stt_en_citrinet_256...\n",
      "[NeMo I 2022-04-21 23:12:45 cloud:56] Found existing object /root/.cache/torch/NeMo/NeMo_1.5.1/stt_en_citrinet_256/91a9cc5850784b2065e8a0aa3d526fd9/stt_en_citrinet_256.nemo.\n",
      "[NeMo I 2022-04-21 23:12:45 cloud:62] Re-using file from: /root/.cache/torch/NeMo/NeMo_1.5.1/stt_en_citrinet_256/91a9cc5850784b2065e8a0aa3d526fd9/stt_en_citrinet_256.nemo\n",
      "[NeMo I 2022-04-21 23:12:45 common:728] Instantiating model from pre-trained checkpoint\n",
      "[NeMo I 2022-04-21 23:12:46 mixins:146] Tokenizer SentencePieceTokenizer initialized with 1024 tokens\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NeMo W 2022-04-21 23:12:47 modelPT:130] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n",
      "    Train config : \n",
      "    manifest_filepath: null\n",
      "    sample_rate: 16000\n",
      "    batch_size: 32\n",
      "    trim_silence: true\n",
      "    max_duration: 16.7\n",
      "    shuffle: true\n",
      "    is_tarred: false\n",
      "    tarred_audio_filepaths: null\n",
      "    use_start_end_token: false\n",
      "    \n",
      "[NeMo W 2022-04-21 23:12:47 modelPT:137] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). \n",
      "    Validation config : \n",
      "    manifest_filepath: null\n",
      "    sample_rate: 16000\n",
      "    batch_size: 32\n",
      "    shuffle: false\n",
      "    use_start_end_token: false\n",
      "    \n",
      "[NeMo W 2022-04-21 23:12:47 modelPT:143] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n",
      "    Test config : \n",
      "    manifest_filepath: null\n",
      "    sample_rate: 16000\n",
      "    batch_size: 32\n",
      "    shuffle: false\n",
      "    use_start_end_token: false\n",
      "    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[NeMo I 2022-04-21 23:12:47 features:265] PADDING: 16\n",
      "[NeMo I 2022-04-21 23:12:47 features:282] STFT using torch\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NeMo W 2022-04-21 23:12:47 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/nemo/collections/asr/parts/preprocessing/features.py:315: FutureWarning: Pass sr=16000, n_fft=512 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n",
      "      librosa.filters.mel(sample_rate, self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq), dtype=torch.float\n",
      "    \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[NeMo I 2022-04-21 23:12:49 save_restore_connector:149] Model EncDecCTCModelBPE was successfully restored from /root/.cache/torch/NeMo/NeMo_1.5.1/stt_en_citrinet_256/91a9cc5850784b2065e8a0aa3d526fd9/stt_en_citrinet_256.nemo.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[NeMo W 2022-04-21 23:12:49 export_utils:198] Swapped 0 modules\n",
      "[NeMo W 2022-04-21 23:12:49 conv_asr:73] Turned off 235 masked convolutions\n",
      "[NeMo W 2022-04-21 23:12:49 export_utils:198] Swapped 0 modules\n",
      "[NeMo W 2022-04-21 23:12:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py:916: UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead\n",
      "      warnings.warn(\n",
      "    \n",
      "[NeMo W 2022-04-21 23:12:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py:668: LightningDeprecationWarning: The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7. Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.\n",
      "      if hasattr(mod, name):\n",
      "    \n",
      "[NeMo W 2022-04-21 23:12:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py:669: LightningDeprecationWarning: The `LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7. Please use the `pytorch_lightning.utilities.memory.get_model_size_mb`.\n",
      "      item = getattr(mod, name)\n",
      "    \n",
      "[NeMo W 2022-04-21 23:12:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py:668: LightningDeprecationWarning: `LightningModule.use_amp` was deprecated in v1.6 and will be removed in v1.8. Please use `Trainer.amp_backend`.\n",
      "      if hasattr(mod, name):\n",
      "    \n",
      "[NeMo W 2022-04-21 23:12:50 nemo_logging:349] /opt/conda/lib/python3.8/site-packages/torch/_jit_internal.py:669: LightningDeprecationWarning: `LightningModule.use_amp` was deprecated in v1.6 and will be removed in v1.8. Please use `Trainer.amp_backend`.\n",
      "      item = getattr(mod, name)\n",
      "    \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(['stt_en_citrinet_256.ts'],\n",
       " ['nemo.collections.asr.models.ctc_bpe_models.EncDecCTCModelBPE exported to ONNX'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "variant = 'stt_en_citrinet_256'\n",
    "\n",
    "print(f\"Downloading and saving {variant}...\")\n",
    "asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(model_name=variant)\n",
    "asr_model.export(f\"{variant}.ts\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Benchmark utility\n",
    "\n",
    "Let us define a helper benchmarking function, then use it to benchmark the exported TorchScript model (without Torch-TensorRT) as a baseline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model: stt_en_citrinet_256.ts\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256.ts =================================\n",
      "batch size=1, num iterations=50\n",
      "  Median samples/s: 102.0, mean: 102.0\n",
      "  Median latency (s): 0.009802, mean: 0.009803, 99th_p: 0.009836, std_dev: 0.000014\n",
      "\n",
      "Loading model: stt_en_citrinet_256.ts\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256.ts =================================\n",
      "batch size=8, num iterations=50\n",
      "  Median samples/s: 429.1, mean: 429.1\n",
      "  Median latency (s): 0.018642, mean: 0.018643, 99th_p: 0.018670, std_dev: 0.000014\n",
      "\n",
      "Loading model: stt_en_citrinet_256.ts\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256.ts =================================\n",
      "batch size=32, num iterations=50\n",
      "  Median samples/s: 551.3, mean: 551.2\n",
      "  Median latency (s): 0.058047, mean: 0.058053, 99th_p: 0.058375, std_dev: 0.000106\n",
      "\n",
      "Loading model: stt_en_citrinet_256.ts\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256.ts =================================\n",
      "batch size=128, num iterations=50\n",
      "  Median samples/s: 594.1, mean: 594.1\n",
      "  Median latency (s): 0.215434, mean: 0.215446, 99th_p: 0.215806, std_dev: 0.000116\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import timeit\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "# Importing torch_tensorrt registers the TensorRT runtime that torch.jit.load\n",
    "# needs to deserialize the Torch-TensorRT modules benchmarked later\n",
    "import torch_tensorrt as torchtrt\n",
    "\n",
    "def benchmark(model, input_tensor, num_loops, model_name, batch_size):\n",
    "    def timeGraph(model, input_tensor, num_loops):\n",
    "        print(\"Warm up ...\")\n",
    "        with torch.no_grad():\n",
    "            for _ in range(20):\n",
    "                features = model(input_tensor)\n",
    "\n",
    "        torch.cuda.synchronize()\n",
    "        print(\"Start timing ...\")\n",
    "        timings = []\n",
    "        with torch.no_grad():\n",
    "            for i in range(num_loops):\n",
    "                start_time = timeit.default_timer()\n",
    "                features = model(input_tensor)\n",
    "                torch.cuda.synchronize()\n",
    "                end_time = timeit.default_timer()\n",
    "                timings.append(end_time - start_time)\n",
    "                # print(\"Iteration {}: {:.6f} s\".format(i, end_time - start_time))\n",
    "        return timings\n",
    "    def printStats(graphName, timings, batch_size):\n",
    "        times = np.array(timings)\n",
    "        steps = len(times)\n",
    "        speeds = batch_size / times\n",
    "        time_mean = np.mean(times)\n",
    "        time_med = np.median(times)\n",
    "        time_99th = np.percentile(times, 99)\n",
    "        time_std = np.std(times, ddof=0)\n",
    "        speed_mean = np.mean(speeds)\n",
    "        speed_med = np.median(speeds)\n",
    "        msg = (\"\\n%s =================================\\n\"\n",
    "                \"batch size=%d, num iterations=%d\\n\"\n",
    "                \"  Median samples/s: %.1f, mean: %.1f\\n\"\n",
    "                \"  Median latency (s): %.6f, mean: %.6f, 99th_p: %.6f, std_dev: %.6f\\n\"\n",
    "                ) % (graphName,\n",
    "                    batch_size, steps,\n",
    "                    speed_med, speed_mean,\n",
    "                    time_med, time_mean, time_99th, time_std)\n",
    "        print(msg)\n",
    "    timings = timeGraph(model, input_tensor, num_loops)\n",
    "    printStats(model_name, timings, batch_size)\n",
    "\n",
    "precisions_str = 'fp32'  # Precision (default=fp32, fp16)\n",
    "variant = 'stt_en_citrinet_256'  # NeMo Citrinet variant\n",
    "batch_sizes = [1, 8, 32, 128]  # Batch sizes (default=1,8,32,128)\n",
    "trt = False  # If True, infer with the Torch-TensorRT modules; otherwise, infer with the exported TorchScript model\n",
    "precision = torch.float32 if precisions_str == 'fp32' else torch.float16\n",
    "\n",
    "for batch_size in batch_sizes:\n",
    "    if trt:\n",
    "        model_name = f\"{variant}_bs{batch_size}_{precision}.torch-tensorrt\"\n",
    "    else:\n",
    "        model_name = f\"{variant}.ts\"\n",
    "\n",
    "    print(f\"Loading model: {model_name}\")\n",
    "    # Load the serialized TorchScript module and move it to the GPU\n",
    "    model = torch.jit.load(model_name).cuda()\n",
    "    cudnn.benchmark = True\n",
    "    # Random input of shape (batch_size, 80 features, 1488 time frames)\n",
    "    torch.manual_seed(12345)\n",
    "    input_shape=(batch_size, 80, 1488)\n",
    "    input_tensor = torch.randn(input_shape).cuda()\n",
    "\n",
    "    # Timing graph inference\n",
    "    benchmark(model, input_tensor, 50, model_name, batch_size)"
   ]
  },
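  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `benchmark` helper reports throughput statistics derived from per-iteration values: `printStats` computes `batch_size / latency` for every iteration and then takes the median and mean of those numbers. The reported mean is therefore not the same as total samples divided by total time. Because $1/x$ is convex, the mean of per-iteration throughputs is always at least the overall throughput, and the gap widens when a few iterations are slow. The minimal NumPy sketch below illustrates this using hypothetical latencies, not measurements from this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical per-iteration latencies in seconds (illustrative, not measured)\n",
    "timings = np.array([0.010, 0.011, 0.012, 0.020])\n",
    "batch_size = 8\n",
    "\n",
    "# Per-iteration throughput, computed the same way as in printStats\n",
    "speeds = batch_size / timings\n",
    "mean_speed = np.mean(speeds)\n",
    "\n",
    "# Overall throughput: total samples processed divided by total elapsed time\n",
    "overall_speed = batch_size * len(timings) / np.sum(timings)\n",
    "\n",
    "print(f\"mean of per-iteration samples/s: {mean_speed:.1f}\")\n",
    "print(f\"overall samples/s:               {overall_speed:.1f}\")"
   ]
  },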
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's confirm which GPU we are using (the results in this notebook were measured on an NVIDIA TITAN V):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Apr 21 23:13:32 2022       \n",
      "+-----------------------------------------------------------------------------+\n",
      "| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |\n",
      "|-------------------------------+----------------------+----------------------+\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
      "|                               |                      |               MIG M. |\n",
      "|===============================+======================+======================|\n",
      "|   0  NVIDIA TITAN V      On   | 00000000:17:00.0 Off |                  N/A |\n",
      "| 38%   55C    P2    42W / 250W |   2462MiB / 12288MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "|   1  NVIDIA TITAN V      On   | 00000000:65:00.0 Off |                  N/A |\n",
      "| 28%   39C    P8    26W / 250W |    112MiB / 12288MiB |      0%      Default |\n",
      "|                               |                      |                  N/A |\n",
      "+-------------------------------+----------------------+----------------------+\n",
      "                                                                               \n",
      "+-----------------------------------------------------------------------------+\n",
      "| Processes:                                                                  |\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
      "|        ID   ID                                                   Usage      |\n",
      "|=============================================================================|\n",
      "|    0   N/A  N/A      3909      G                                       4MiB |\n",
      "|    0   N/A  N/A      6047      C                                    2453MiB |\n",
      "|    1   N/A  N/A      3909      G                                      39MiB |\n",
      "|    1   N/A  N/A      4161      G                                      67MiB |\n",
      "+-----------------------------------------------------------------------------+\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"3\"></a>\n",
    "## 3. Create Torch-TensorRT modules\n",
    "\n",
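    "In this step, we compile the Citrinet TorchScript module with Torch-TensorRT for every combination of precision (FP32 and FP16) and batch size (1, 8, 32 and 128). Each compiled module is built for a fixed input shape of `[batch_size, 80, 1488]` and saved to its own file."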
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generating Torchscript-TensorRT module for batchsize 1 precision torch.float32\n",
      "Generating Torchscript-TensorRT module for batchsize 8 precision torch.float32\n",
      "Generating Torchscript-TensorRT module for batchsize 32 precision torch.float32\n",
      "Generating Torchscript-TensorRT module for batchsize 128 precision torch.float32\n",
      "Generating Torchscript-TensorRT module for batchsize 1 precision torch.float16\n",
      "Generating Torchscript-TensorRT module for batchsize 8 precision torch.float16\n",
      "Generating Torchscript-TensorRT module for batchsize 32 precision torch.float16\n",
      "Generating Torchscript-TensorRT module for batchsize 128 precision torch.float16\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch_tensorrt as torchtrt\n",
    "\n",
    "variant = \"stt_en_citrinet_256\"\n",
    "precisions = [torch.float, torch.half]\n",
    "batch_sizes = [1, 8, 32, 128]\n",
    "\n",
    "# Load the TorchScript module exported from NeMo in step 2\n",
    "model = torch.jit.load(f\"{variant}.ts\")\n",
    "\n",
    "for precision in precisions:\n",
    "    for batch_size in batch_sizes:\n",
    "        compile_settings = {\n",
    "            # Static input shape: (batch_size, 80, 1488)\n",
    "            \"inputs\": [torchtrt.Input(shape=[batch_size, 80, 1488])],\n",
    "            # Precision(s) TensorRT is allowed to use for its kernels\n",
    "            \"enabled_precisions\": {precision},\n",
    "            # Maximum TensorRT builder workspace in bytes (about 2 GB)\n",
    "            \"workspace_size\": 2000000000,\n",
    "            # Truncate int64/float64 values to int32/float32 for TensorRT\n",
    "            \"truncate_long_and_double\": True,\n",
    "        }\n",
    "        print(f\"Generating Torchscript-TensorRT module for batchsize {batch_size} precision {precision}\")\n",
    "        trt_ts_module = torchtrt.compile(model, **compile_settings)\n",
    "        # File name encodes batch size and precision, e.g. stt_en_citrinet_256_bs1_torch.float32.torch-tensorrt\n",
    "        torch.jit.save(trt_ts_module, f\"{variant}_bs{batch_size}_{precision}.torch-tensorrt\")"
   ]
  },
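  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each module compiled above accepts only the static input shape it was built with, `[batch_size, 80, 1488]`. To relate the 1488-frame time axis to audio length, the sketch below converts frames to seconds and turns a throughput figure into seconds of audio processed per wall-clock second. It assumes a 10 ms feature-frame stride (`WINDOW_STRIDE_S`, a hypothetical constant introduced here); the stride is not shown in this notebook, so verify it against the model's preprocessor configuration. Since the benchmarks feed random features, the result ignores audio loading and feature extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumption: one feature frame every 10 ms (verify against the model's preprocessor config)\n",
    "WINDOW_STRIDE_S = 0.01\n",
    "NUM_FRAMES = 1488  # time dimension of the benchmark input (batch_size, 80, 1488)\n",
    "\n",
    "# Seconds of audio represented by one benchmark sample under the assumed stride\n",
    "audio_seconds_per_sample = NUM_FRAMES * WINDOW_STRIDE_S\n",
    "\n",
    "def audio_seconds_per_second(samples_per_s):\n",
    "    \"\"\"Seconds of audio processed per wall-clock second at a given throughput.\"\"\"\n",
    "    return samples_per_s * audio_seconds_per_sample\n",
    "\n",
    "print(f\"One sample covers about {audio_seconds_per_sample:.2f} s of audio\")\n",
    "# Example: median baseline TorchScript throughput at batch size 128 (594.1 samples/s)\n",
    "print(f\"{audio_seconds_per_second(594.1):.0f} s of audio per second\")"
   ]
  },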
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"4\"></a>\n",
    "## 4. Benchmark Torch-TensorRT models\n",
    "\n",
    "Finally, we are ready to benchmark the Torch-TensorRT optimized Citrinet models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FP32 (single precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model: stt_en_citrinet_256_bs1_torch.float32.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs1_torch.float32.torch-tensorrt =================================\n",
      "batch size=1, num iterations=50\n",
      "  Median samples/s: 242.2, mean: 218.0\n",
      "  Median latency (s): 0.004128, mean: 0.004825, 99th_p: 0.008071, std_dev: 0.001270\n",
      "\n",
      "Loading model: stt_en_citrinet_256_bs8_torch.float32.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs8_torch.float32.torch-tensorrt =================================\n",
      "batch size=8, num iterations=50\n",
      "  Median samples/s: 729.9, mean: 709.0\n",
      "  Median latency (s): 0.010961, mean: 0.011388, 99th_p: 0.016114, std_dev: 0.001256\n",
      "\n",
      "Loading model: stt_en_citrinet_256_bs32_torch.float32.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs32_torch.float32.torch-tensorrt =================================\n",
      "batch size=32, num iterations=50\n",
      "  Median samples/s: 955.6, mean: 953.4\n",
      "  Median latency (s): 0.033488, mean: 0.033572, 99th_p: 0.035722, std_dev: 0.000545\n",
      "\n",
      "Loading model: stt_en_citrinet_256_bs128_torch.float32.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs128_torch.float32.torch-tensorrt =================================\n",
      "batch size=128, num iterations=50\n",
      "  Median samples/s: 1065.8, mean: 1069.4\n",
      "  Median latency (s): 0.120097, mean: 0.119708, 99th_p: 0.121618, std_dev: 0.001260\n",
      "\n"
     ]
    }
   ],
   "source": [
    "precisions_str = 'fp32' # Precision (default=fp32, fp16)\n",
    "batch_sizes = [1, 8, 32, 128] # Batch sizes (default=1,8,32,128)\n",
    "precision = torch.float32 if precisions_str =='fp32' else torch.float16\n",
    "trt = True\n",
    "\n",
    "for batch_size in batch_sizes:\n",
    "    if trt:\n",
    "        model_name = f\"{variant}_bs{batch_size}_{precision}.torch-tensorrt\"\n",
    "    else:\n",
    "        model_name = f\"{variant}.ts\"\n",
    "\n",
    "    print(f\"Loading model: {model_name}\")\n",
    "    # Load the serialized Torch-TensorRT module and move it to the GPU\n",
    "    model = torch.jit.load(model_name).cuda()\n",
    "    cudnn.benchmark = True\n",
    "    # Random input of shape (batch_size, 80 features, 1488 time frames)\n",
    "    torch.manual_seed(12345)\n",
    "    input_shape=(batch_size, 80, 1488)\n",
    "    input_tensor = torch.randn(input_shape).cuda()\n",
    "\n",
    "    # Timing graph inference\n",
    "    benchmark(model, input_tensor, 50, model_name, batch_size)"
   ]
  },
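  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To put the FP32 results in context, the sketch below divides the median Torch-TensorRT FP32 throughputs above by the median throughputs of the baseline TorchScript model measured earlier. The numbers are copied from this notebook's printed outputs on the TITAN V shown by `nvidia-smi`; other GPUs and software versions will give different ratios."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median samples/s copied from the benchmark outputs in this notebook (TITAN V)\n",
    "baseline_fp32 = {1: 102.0, 8: 429.1, 32: 551.3, 128: 594.1}  # TorchScript model\n",
    "trt_fp32 = {1: 242.2, 8: 729.9, 32: 955.6, 128: 1065.8}  # Torch-TensorRT FP32\n",
    "\n",
    "# Speedup = Torch-TensorRT throughput / baseline throughput, per batch size\n",
    "speedups = {bs: trt_fp32[bs] / baseline_fp32[bs] for bs in baseline_fp32}\n",
    "\n",
    "for bs, s in speedups.items():\n",
    "    print(f\"batch size {bs:>3}: {s:.2f}x\")"
   ]
  },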
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### FP16 (half precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model: stt_en_citrinet_256_bs1_torch.float16.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs1_torch.float16.torch-tensorrt =================================\n",
      "batch size=1, num iterations=50\n",
      "  Median samples/s: 288.9, mean: 272.9\n",
      "  Median latency (s): 0.003462, mean: 0.003774, 99th_p: 0.006846, std_dev: 0.000820\n",
      "\n",
      "Loading model: stt_en_citrinet_256_bs8_torch.float16.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs8_torch.float16.torch-tensorrt =================================\n",
      "batch size=8, num iterations=50\n",
      "  Median samples/s: 1201.0, mean: 1190.9\n",
      "  Median latency (s): 0.006661, mean: 0.006733, 99th_p: 0.008453, std_dev: 0.000368\n",
      "\n",
      "Loading model: stt_en_citrinet_256_bs32_torch.float16.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs32_torch.float16.torch-tensorrt =================================\n",
      "batch size=32, num iterations=50\n",
      "  Median samples/s: 1538.2, mean: 1516.4\n",
      "  Median latency (s): 0.020804, mean: 0.021143, 99th_p: 0.024492, std_dev: 0.000973\n",
      "\n",
      "Loading model: stt_en_citrinet_256_bs128_torch.float16.torch-tensorrt\n",
      "Warm up ...\n",
      "Start timing ...\n",
      "\n",
      "stt_en_citrinet_256_bs128_torch.float16.torch-tensorrt =================================\n",
      "batch size=128, num iterations=50\n",
      "  Median samples/s: 1792.0, mean: 1777.0\n",
      "  Median latency (s): 0.071428, mean: 0.072057, 99th_p: 0.076796, std_dev: 0.001351\n",
      "\n"
     ]
    }
   ],
   "source": [
    "precisions_str = 'fp16' # Precision (default=fp32, fp16)\n",
    "batch_sizes = [1, 8, 32, 128] # Batch sizes (default=1,8,32,128)\n",
    "precision = torch.float32 if precisions_str =='fp32' else torch.float16\n",
    "\n",
    "for batch_size in batch_sizes:\n",
    "    if trt:\n",
    "        model_name = f\"{variant}_bs{batch_size}_{precision}.torch-tensorrt\"\n",
    "    else:\n",
    "        model_name = f\"{variant}.ts\"\n",
    "\n",
    "    print(f\"Loading model: {model_name}\") \n",
    "    # Load traced model to CPU first\n",
    "    model = torch.jit.load(model_name).cuda()\n",
    "    cudnn.benchmark = True\n",
    "    # Create random input tensor of certain size\n",
    "    torch.manual_seed(12345)\n",
    "    input_shape=(batch_size, 80, 1488)\n",
    "    input_tensor = torch.randn(input_shape).cuda()\n",
    "\n",
    "    # Timing graph inference\n",
    "    benchmark(model, input_tensor, 50, model_name, batch_size)"
   ]
  },
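  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparing the two Torch-TensorRT runs shows how much FP16 adds on top of FP32; this is a different comparison from the speedup over the unoptimized baseline. The sketch below divides the FP16 median throughput by the FP32 median throughput for batch sizes 32 and 128, using values copied by hand from the outputs above (other batch sizes can be added the same way). The variable names are illustrative. With these values the ratio comes out to roughly 1.61x at batch size 32 and 1.68x at batch size 128."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Median samples/s copied from the Torch-TensorRT FP32 and FP16 outputs above\n",
    "fp32_samples_per_s = {32: 955.6, 128: 1065.8}\n",
    "fp16_samples_per_s = {32: 1538.2, 128: 1792.0}\n",
    "\n",
    "# A ratio above 1 means the FP16 engine processes more samples per second\n",
    "fp16_over_fp32 = {\n",
    "    bs: fp16_samples_per_s[bs] / fp32_samples_per_s[bs] for bs in fp32_samples_per_s\n",
    "}\n",
    "\n",
    "for bs, ratio in fp16_over_fp32.items():\n",
    "    print(f'batch size={bs}: FP16 throughput is {ratio:.2f}x FP32')"
   ]
  },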
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"5\"></a>\n",
    "## 5. Conclusion\n",
    "\n",
    "In this notebook, we have walked through the complete process of optimizing the Citrinet model with Torch-TensorRT. On an A100 GPU, with Torch-TensorRT, we observe a speedup of ~**2.4X** with FP32, and ~**2.9X** with FP16 at batchsize of 128.\n",
    "\n",
    "### What's next\n",
    "Now it's time to try Torch-TensorRT on your own model. Fill out issues at https://github.com/NVIDIA/Torch-TensorRT. Your involvement will help future development of Torch-TensorRT.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
